import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython import embed

from geotransformer.modules.ops import point_to_node_partition, index_select
from geotransformer.modules.registration import get_node_correspondences
from geotransformer.modules.sinkhorn import LearnableLogOptimalTransport
from geotransformer.modules.geotransformer import (
    GeometricTransformer,
    SuperPointMatching,
    SuperPointTargetGenerator,
    LocalGlobalRegistration,
)

# 3D Gaussian Splatting (3DGS) modules: Gaussian parametrization and Gaussian-based coarse matching.
from geotransformer.modules.threedgs.gauss import (
    GS_Parametrization,
)
from geotransformer.modules.threedgs.gauss_match import (
    GaussianCoarseRegistration,
    CoarseMatchFusion,

)

# end of 3DGS imports

from backbone import KPConvFPN

import os

def save_matching_results_txt(ref_node_corr_indices_1, src_node_corr_indices_1, node_corr_scores_1,
                               ref_node_corr_indices_2, src_node_corr_indices_2, node_corr_scores_2,
                               output_dir):
    """Dump two sets of node correspondences to plain-text files, one value per line.

    ``output_dir`` is created if it does not exist. Six files are written, in
    this order: ``ref_node_corr_indices_1.txt``, ``src_node_corr_indices_1.txt``,
    ``node_corr_scores_1.txt``, then the same three names with the ``_2`` suffix.
    Each tensor is moved to CPU, converted to NumPy and iterated along its first
    dimension; every element is formatted with an f-string and followed by a
    newline. A confirmation message is printed once all files are written.
    """
    os.makedirs(output_dir, exist_ok=True)

    # (file name, tensor) pairs, in the order the files are written.
    outputs = (
        ("ref_node_corr_indices_1.txt", ref_node_corr_indices_1),
        ("src_node_corr_indices_1.txt", src_node_corr_indices_1),
        ("node_corr_scores_1.txt", node_corr_scores_1),
        ("ref_node_corr_indices_2.txt", ref_node_corr_indices_2),
        ("src_node_corr_indices_2.txt", src_node_corr_indices_2),
        ("node_corr_scores_2.txt", node_corr_scores_2),
    )

    for file_name, values in outputs:
        with open(f"{output_dir}/{file_name}", 'w') as handle:
            handle.writelines(f"{value}\n" for value in values.cpu().numpy())

    print(f"Matching results saved to {output_dir}")


def update_by_frequency(ref_node_corr_indices_1, src_node_corr_indices_1, ref_node_corr_indices_2, src_node_corr_indices_2, top_k=256):
    """Return copies of the first correspondence set, truncated to at most ``top_k`` entries.

    NOTE(review): an earlier version of this function ranked the values of the
    second correspondence set by frequency (``torch.unique(..., return_counts=True)``).
    It then applied ``torch.where(x == idx, idx, x)`` for every ranked value.
    That expression replaces each element with itself, so the ranking never
    influenced the result and only cost one full tensor pass per unique value.
    The dead work has been removed; the observable behaviour (copy + truncation)
    is unchanged. If a frequency-based re-ranking was intended, it still needs
    to be implemented.

    Args:
        ref_node_corr_indices_1: reference node indices of the first set
            (presumably a 1-D tensor; only ``size(0)`` and slicing are used).
        src_node_corr_indices_1: source node indices of the first set, presumably
            paired element-wise with ``ref_node_corr_indices_1``.
        ref_node_corr_indices_2: reference indices of the second set; currently unused.
        src_node_corr_indices_2: source indices of the second set; currently unused.
        top_k: maximum number of correspondences to keep.

    Returns:
        ``(ref, src)`` tensors that do not share storage with the inputs. When
        the reference tensor has more than ``top_k`` entries, both outputs are
        cut to their first ``top_k`` entries; otherwise both are returned whole.
    """
    updated_ref_node_corr_indices_1 = ref_node_corr_indices_1.clone()
    updated_src_node_corr_indices_1 = src_node_corr_indices_1.clone()

    # Truncation is decided by the reference length only, and is applied to both
    # tensors so the ref/src pairs stay aligned.
    if updated_ref_node_corr_indices_1.size(0) > top_k:
        updated_ref_node_corr_indices_1 = updated_ref_node_corr_indices_1[:top_k]
        updated_src_node_corr_indices_1 = updated_src_node_corr_indices_1[:top_k]

    return updated_ref_node_corr_indices_1, updated_src_node_corr_indices_1




class GeoTransformer(nn.Module):
    """GeoTransformer registration network extended with HSV colours and per-node 3D Gaussians.

    Pipeline (see ``forward``): split stacked ref/src clouds per level, build
    xyz+HSV patches around coarse nodes and fit 3DGS parameters, compute
    ground-truth node correspondences, run the KPConv backbone, refine coarse
    features with the geometric transformer (conditioned on colours and 3DGS
    parameters), fuse feature-based and Gaussian-based coarse matches, then
    run optimal transport and local-to-global fine registration.
    """

    def __init__(self, cfg):
        super(GeoTransformer, self).__init__()
        self.num_points_in_patch = cfg.model.num_points_in_patch
        self.matching_radius = cfg.model.ground_truth_matching_radius

        # Per-node Gaussian parametrization computed from xyz+HSV patches.
        self.gs_parametrization = GS_Parametrization(num_points_in_patch=cfg.model.num_points_in_patch)

        # KPConv feature-pyramid backbone (7 positional config arguments).
        self.backbone = KPConvFPN(
            cfg.backbone.input_dim,
            cfg.backbone.output_dim,
            cfg.backbone.init_dim,
            cfg.backbone.kernel_size,
            cfg.backbone.init_radius,
            cfg.backbone.init_sigma,
            cfg.backbone.group_norm,
        )

        # Geometric transformer on coarse features; additionally receives
        # colour and Gaussian inputs (sigma_gs / sigma_color / sigma_hd).
        self.transformer = GeometricTransformer(
            cfg.geotransformer.input_dim,
            cfg.geotransformer.output_dim,
            cfg.geotransformer.hidden_dim,
            cfg.geotransformer.num_heads,
            cfg.geotransformer.blocks,
            cfg.geotransformer.sigma_d,
            cfg.geotransformer.sigma_a,
            cfg.geotransformer.angle_k,
            cfg.geotransformer.sigma_gs,
            sigma_color=cfg.geotransformer.sigma_color,
            sigma_hd=cfg.geotransformer.sigma_hd,
            reduction_a=cfg.geotransformer.reduction_a
        )

        # Samples ground-truth node correspondences (used in training, see forward).
        self.coarse_target = SuperPointTargetGenerator(
            cfg.coarse_matching.num_targets, cfg.coarse_matching.overlap_threshold
        )

        # Feature-based coarse (superpoint) matching.
        self.coarse_matching = SuperPointMatching(
            cfg.coarse_matching.num_correspondences, cfg.coarse_matching.dual_normalization
        )

        # Gaussian-based coarse matching and fusion of the two coarse match sets.
        # NOTE(review): these read top-level cfg.num_correspondences / cfg.length
        # rather than cfg.coarse_matching.* — confirm this is intended.
        self.gaussian_coarse_matching = GaussianCoarseRegistration(cfg.num_correspondences)
        self.coarse_fusion_matching = CoarseMatchFusion(cfg.length)

        # Fine matching and transform estimation.
        self.fine_matching = LocalGlobalRegistration(
            cfg.fine_matching.topk,
            cfg.fine_matching.acceptance_radius,
            mutual=cfg.fine_matching.mutual,
            confidence_threshold=cfg.fine_matching.confidence_threshold,
            use_dustbin=cfg.fine_matching.use_dustbin,
            use_global_score=cfg.fine_matching.use_global_score,
            correspondence_threshold=cfg.fine_matching.correspondence_threshold,
            correspondence_limit=cfg.fine_matching.correspondence_limit,
            num_refinement_steps=cfg.fine_matching.num_refinement_steps,
        )

        # Sinkhorn optimal transport over patch-to-patch point scores.
        self.optimal_transport = LearnableLogOptimalTransport(cfg.model.num_sinkhorn_iterations)

    def forward(self, data_dict):
        """Run the coarse-to-fine registration pipeline on one stacked ref/src pair.

        Args:
            data_dict: batch dictionary. Keys read directly here: 'features',
                'transform', 'lengths', 'points', 'hsv' (the backbone also
                receives the whole dict). 'points', 'hsv' and 'lengths' are
                indexed per level: [-1] is used as the coarse (node) level,
                [1] as the fine level and [0] as the densest level. At each
                level the reference cloud comes first, then the source cloud,
                and 'lengths'[level][0] is the reference count.

        Returns:
            dict with split points/colours per level, 3DGS parameters,
            ground-truth node correspondences, normalized coarse features,
            fine features, predicted node correspondences, the selected
            patches, optimal-transport 'matching_scores', and the fine-matching
            outputs ('ref_corr_points', 'src_corr_points', 'corr_scores',
            'estimated_transform').
        """
        output_dict = {}

        feats = data_dict['features'].detach()
        transform = data_dict['transform'].detach()

        # Split the stacked (ref followed by src) points at each level.
        ref_length_c = data_dict['lengths'][-1][0].item()
        ref_length_f = data_dict['lengths'][1][0].item()
        ref_length = data_dict['lengths'][0][0].item()
        points_c = data_dict['points'][-1].detach()
        points_f = data_dict['points'][1].detach()
        points = data_dict['points'][0].detach()

        ref_points_c = points_c[:ref_length_c]
        src_points_c = points_c[ref_length_c:]
        ref_points_f = points_f[:ref_length_f]
        src_points_f = points_f[ref_length_f:]
        ref_points = points[:ref_length]
        src_points = points[ref_length:]

        output_dict['ref_points_c'] = ref_points_c
        output_dict['src_points_c'] = src_points_c
        output_dict['ref_points_f'] = ref_points_f
        output_dict['src_points_f'] = src_points_f
        output_dict['ref_points'] = ref_points
        output_dict['src_points'] = src_points

        # HSV colours are split with the same per-level reference lengths.
        color_hsv_c = data_dict['hsv'][-1].detach()
        color_hsv_f = data_dict['hsv'][1].detach()
        color_hsv = data_dict['hsv'][0].detach()

        ref_hsv_c = color_hsv_c[:ref_length_c]
        src_hsv_c = color_hsv_c[ref_length_c:]
        ref_hsv_f = color_hsv_f[:ref_length_f]
        src_hsv_f = color_hsv_f[ref_length_f:]
        ref_hsv = color_hsv[:ref_length]
        src_hsv = color_hsv[ref_length:]
        output_dict['ref_hsv'] = ref_hsv
        output_dict['src_hsv'] = src_hsv
        output_dict['ref_hsv_c'] = ref_hsv_c
        output_dict['ref_hsv_f'] = ref_hsv_f
        output_dict['src_hsv_c'] = src_hsv_c
        # 'src_hav_f' is a historical misspelling kept for existing consumers;
        # 'src_hsv_f' is the correctly named key matching 'ref_hsv_f'.
        output_dict['src_hav_f'] = src_hsv_f
        output_dict['src_hsv_f'] = src_hsv_f

        # 1. Partition fine points into per-node kNN patches.
        _, ref_node_masks, ref_node_knn_indices, ref_node_knn_masks = point_to_node_partition(
            ref_points_f, ref_points_c, self.num_points_in_patch
        )
        _, src_node_masks, src_node_knn_indices, src_node_knn_masks = point_to_node_partition(
            src_points_f, src_points_c, self.num_points_in_patch
        )

        # Append one zero row so a padding index equal to len(points_f) gathers a zero point.
        ref_padded_points_f = torch.cat([ref_points_f, torch.zeros_like(ref_points_f[:1])], dim=0)
        src_padded_points_f = torch.cat([src_points_f, torch.zeros_like(src_points_f[:1])], dim=0)
        ref_node_knn_points = index_select(ref_padded_points_f, ref_node_knn_indices, dim=0)
        src_node_knn_points = index_select(src_padded_points_f, src_node_knn_indices, dim=0)

        # Colours are not zero-padded; out-of-range kNN indices are clamped to
        # the last colour so the gather stays in bounds.
        max_ref_index = ref_hsv_f.shape[0] - 1
        max_src_index = src_hsv_f.shape[0] - 1
        ref_node_knn_indices_1 = torch.clamp(ref_node_knn_indices, min=0, max=max_ref_index)
        src_node_knn_indices_1 = torch.clamp(src_node_knn_indices, min=0, max=max_src_index)

        ref_node_knn_colors = index_select(ref_hsv_f, ref_node_knn_indices_1, dim=0)  # [M, K, 3]
        src_node_knn_colors = index_select(src_hsv_f, src_node_knn_indices_1, dim=0)  # [M, K, 3]

        # xyz + HSV patches fed to the Gaussian parametrization.
        ref_patches = torch.cat([ref_node_knn_points, ref_node_knn_colors], dim=-1)  # [M, K, 6]
        src_patches = torch.cat([src_node_knn_points, src_node_knn_colors], dim=-1)  # [M, K, 6]

        # A patch slot is valid when the partition marks it valid AND its
        # *unclamped* index refers to a real colour. (Comparing the clamped
        # indices, as before, was always True and disabled this guard.)
        ref_valid_mask = ref_node_knn_masks & (ref_node_knn_indices < ref_hsv_f.shape[0])  # [M, K]
        src_valid_mask = src_node_knn_masks & (src_node_knn_indices < src_hsv_f.shape[0])  # [M, K]

        # A node is valid if at least one of its patch slots is valid.
        ref_valid_mask = ref_valid_mask.any(dim=-1)  # [M]
        src_valid_mask = src_valid_mask.any(dim=-1)  # [M]

        # Defensive flatten to a 1-D node mask.
        if ref_valid_mask.dim() != 1:
            ref_valid_mask = ref_valid_mask.view(-1)
        if src_valid_mask.dim() != 1:
            src_valid_mask = src_valid_mask.view(-1)

        # NOTE(review): only the per-node mask reaches gs_parametrization, so
        # padded slots inside a valid node still carry (zero xyz, clamped colour)
        # — confirm GS_Parametrization tolerates this.
        ref_gs_params = self.gs_parametrization(ref_patches, ref_valid_mask)
        src_gs_params = self.gs_parametrization(src_patches, src_valid_mask)

        output_dict['ref_gs_params'] = ref_gs_params
        output_dict['src_gs_params'] = src_gs_params

        # 2. Ground-truth node correspondences.
        gt_node_corr_indices, gt_node_corr_overlaps = get_node_correspondences(
            ref_points_c,
            src_points_c,
            ref_node_knn_points,
            src_node_knn_points,
            transform,
            self.matching_radius,
            ref_masks=ref_node_masks,
            src_masks=src_node_masks,
            ref_knn_masks=ref_node_knn_masks,
            src_knn_masks=src_node_knn_masks,
        )

        output_dict['gt_node_corr_indices'] = gt_node_corr_indices
        output_dict['gt_node_corr_overlaps'] = gt_node_corr_overlaps

        # 3. KPConv backbone: last output is the coarse level, first is the fine level.
        feats_list = self.backbone(feats, data_dict)

        feats_c = feats_list[-1]
        feats_f = feats_list[0]

        # 4. Geometric transformer on coarse features, conditioned on colours and Gaussians.
        ref_feats_c = feats_c[:ref_length_c]
        src_feats_c = feats_c[ref_length_c:]

        ref_feats_c, src_feats_c = self.transformer(
            ref_points_c.unsqueeze(0),
            src_points_c.unsqueeze(0),
            ref_feats_c.unsqueeze(0),
            src_feats_c.unsqueeze(0),
            ref_colors=ref_hsv_c.unsqueeze(0),
            src_colors=src_hsv_c.unsqueeze(0),
            ref_gs=output_dict['ref_gs_params'].unsqueeze(0),  # 3DGS parameters
            src_gs=output_dict['src_gs_params'].unsqueeze(0),  # 3DGS parameters
        )

        ref_feats_c_norm = F.normalize(ref_feats_c.squeeze(0), p=2, dim=1)
        src_feats_c_norm = F.normalize(src_feats_c.squeeze(0), p=2, dim=1)

        output_dict['ref_feats_c'] = ref_feats_c_norm
        output_dict['src_feats_c'] = src_feats_c_norm

        # 5. Fine-level features.
        ref_feats_f = feats_f[:ref_length_f]
        src_feats_f = feats_f[ref_length_f:]
        output_dict['ref_feats_f'] = ref_feats_f
        output_dict['src_feats_f'] = src_feats_f

        # 6. Coarse correspondences: fuse feature-based and Gaussian-based matches.
        with torch.no_grad():
            ref_node_corr_indices_1, src_node_corr_indices_1, node_corr_scores_1 = self.coarse_matching(
                ref_feats_c_norm, src_feats_c_norm, ref_node_masks, src_node_masks
            )
            ref_node_corr_indices_2, src_node_corr_indices_2, node_corr_scores_2 = self.gaussian_coarse_matching(
                output_dict['ref_gs_params'], output_dict['src_gs_params']
            )

            ref_node_corr_indices, src_node_corr_indices, node_corr_scores = self.coarse_fusion_matching(
                ref_node_corr_indices_1, src_node_corr_indices_1, ref_node_corr_indices_2, src_node_corr_indices_2,
                node_corr_scores_1, node_corr_scores_2
            )

            # The predicted (fused) correspondences are reported even in training.
            output_dict['ref_node_corr_indices'] = ref_node_corr_indices
            output_dict['src_node_corr_indices'] = src_node_corr_indices

            # 7. In training, patches are built from sampled ground-truth correspondences instead.
            if self.training:
                ref_node_corr_indices, src_node_corr_indices, node_corr_scores = self.coarse_target(
                    gt_node_corr_indices, gt_node_corr_overlaps
                )

        # 7.2 Gather the patches (points, masks, features) of the selected node pairs.
        ref_node_corr_knn_indices = ref_node_knn_indices[ref_node_corr_indices]  # (P, K)
        src_node_corr_knn_indices = src_node_knn_indices[src_node_corr_indices]  # (P, K)
        ref_node_corr_knn_masks = ref_node_knn_masks[ref_node_corr_indices]  # (P, K)
        src_node_corr_knn_masks = src_node_knn_masks[src_node_corr_indices]  # (P, K)
        ref_node_corr_knn_points = ref_node_knn_points[ref_node_corr_indices]  # (P, K, 3)
        src_node_corr_knn_points = src_node_knn_points[src_node_corr_indices]  # (P, K, 3)

        # Zero row appended so padding indices gather a zero feature.
        ref_padded_feats_f = torch.cat([ref_feats_f, torch.zeros_like(ref_feats_f[:1])], dim=0)
        src_padded_feats_f = torch.cat([src_feats_f, torch.zeros_like(src_feats_f[:1])], dim=0)
        ref_node_corr_knn_feats = index_select(ref_padded_feats_f, ref_node_corr_knn_indices, dim=0)  # (P, K, C)
        src_node_corr_knn_feats = index_select(src_padded_feats_f, src_node_corr_knn_indices, dim=0)  # (P, K, C)

        output_dict['ref_node_corr_knn_points'] = ref_node_corr_knn_points
        output_dict['src_node_corr_knn_points'] = src_node_corr_knn_points
        output_dict['ref_node_corr_knn_masks'] = ref_node_corr_knn_masks
        output_dict['src_node_corr_knn_masks'] = src_node_corr_knn_masks

        # 8. Optimal transport over scaled patch-to-patch feature similarities.
        matching_scores = torch.einsum('bnd,bmd->bnm', ref_node_corr_knn_feats, src_node_corr_knn_feats)  # (P, K, K)
        matching_scores = matching_scores / feats_f.shape[1] ** 0.5
        matching_scores = self.optimal_transport(matching_scores, ref_node_corr_knn_masks, src_node_corr_knn_masks)

        output_dict['matching_scores'] = matching_scores

        # 9. Final correspondences and transform estimate (no gradients).
        with torch.no_grad():
            # Drop the last row/column (dustbin) when the fine matcher does not use it.
            if not self.fine_matching.use_dustbin:
                matching_scores = matching_scores[:, :-1, :-1]

            ref_corr_points, src_corr_points, corr_scores, estimated_transform = self.fine_matching(
                ref_node_corr_knn_points,
                src_node_corr_knn_points,
                ref_node_corr_knn_masks,
                src_node_corr_knn_masks,
                matching_scores,
                node_corr_scores,
            )

            output_dict['ref_corr_points'] = ref_corr_points
            output_dict['src_corr_points'] = src_corr_points
            output_dict['corr_scores'] = corr_scores
            output_dict['estimated_transform'] = estimated_transform

        return output_dict


def create_model(config):
    """Construct and return a ``GeoTransformer`` configured by *config*."""
    return GeoTransformer(config)


def main():
    """Build the model from ``make_cfg()`` and print its state-dict keys and module tree."""
    from config import make_cfg

    model = create_model(make_cfg())
    print(model.state_dict().keys())
    print(model)


if __name__ == '__main__':
    main()
